{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 文本预处理\n",
    ":label:`sec_text_preprocessing`\n",
    "\n",
    "对于序列数据处理问题，我们在 :numref:`sec_sequence`中\n",
    "评估了所需的统计工具和预测时面临的挑战。\n",
    "这样的数据存在许多种形式，文本是最常见例子之一。\n",
    "例如，一篇文章可以被简单地看作是一串单词序列，甚至是一串字符序列。\n",
    "本节中，我们将解析文本的常见预处理步骤。\n",
    "这些步骤通常包括：\n",
    "\n",
    "1. 将文本作为字符串加载到内存中。\n",
    "1. 将字符串拆分为词元（如单词和字符）。\n",
    "1. 建立一个词表，将拆分的词元映射到数字索引。\n",
    "1. 将文本转换为数字索引序列，方便模型操作。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取数据集\n",
    "\n",
    "首先，我们从H.G.Well的[时光机器](https://www.gutenberg.org/ebooks/35)中加载文本。\n",
    "这是一个相当小的语料库，只有30000多个单词，但足够我们小试牛刀，\n",
    "而现实中的文档集合可能会包含数十亿个单词。\n",
    "下面的函数将数据集读取到由多条文本行组成的列表中，其中每条文本行都是一个字符串。\n",
    "为简单起见，我们在这里忽略了标点符号和字母大写。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public String[] readTimeMachine() throws IOException {\n",
    "    URL url = new URL(\"http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt\");\n",
    "    String[] lines;\n",
    "    try (BufferedReader in = new BufferedReader(new InputStreamReader(url.openStream()))) {\n",
    "        lines = in.lines().toArray(String[]::new);\n",
    "    }\n",
    "\n",
    "    for (int i = 0; i < lines.length; i++) {\n",
    "        lines[i] = lines[i].replaceAll(\"[^A-Za-z]+\", \" \").strip().toLowerCase();\n",
    "    }\n",
    "    return lines;\n",
    "}\n",
    "\n",
    "String[] lines = readTimeMachine();\n",
    "System.out.println(\"# text lines: \" + lines.length);\n",
    "System.out.println(lines[0]);\n",
    "System.out.println(lines[10]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 词元化\n",
    "\n",
    "下面的`tokenize`函数将文本行列表（`lines`）作为输入，\n",
    "列表中的每个元素是一个文本序列（如一条文本行）。\n",
    "每个文本序列又被拆分成一个词元列表，*词元*（token）是文本的基本单位。\n",
    "最后，返回一个由词元列表组成的列表，其中的每个词元都是一个字符串（string）。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public String[][] tokenize(String[] lines, String token) throws Exception {\n",
    "    // 将文本行拆分为单词或字符标记\n",
    "    String[][] output = new String[lines.length][];\n",
    "    if (token == \"word\") {\n",
    "        for (int i = 0; i < output.length; i++) {\n",
    "            output[i] = lines[i].split(\" \");\n",
    "        }\n",
    "    } else if (token == \"char\") {\n",
    "        for (int i = 0; i < output.length; i++) {\n",
    "            output[i] = lines[i].split(\"\");\n",
    "        }\n",
    "    } else {\n",
    "        throw new Exception(\"ERROR: unknown token type: \" + token);\n",
    "    }\n",
    "    return output; \n",
    "}\n",
    "String[][] tokens = tokenize(lines, \"word\");\n",
    "for (int i = 0; i < 11; i++) {\n",
    "    System.out.println(Arrays.toString(tokens[i]));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 词汇表\n",
    "\n",
    "词元的类型是字符串，而模型需要的输入是数字，因此这种类型不方便模型使用。\n",
    "现在，让我们构建一个字典，通常也叫做*词汇表*（vocabulary），\n",
    "用来将字符串类型的词元映射到从$0$开始的数字索引中。\n",
    "我们先将训练集中的所有文档合并在一起，对它们的唯一词元进行统计，\n",
    "得到的统计结果称之为*语料*（corpus）。\n",
    "然后根据每个唯一词元的出现频率，为其分配一个数字索引。\n",
    "很少出现的词元通常被移除，这可以降低复杂性。\n",
    "另外，语料库中不存在或已删除的任何词元都将映射到一个特定的未知词元“&lt;unk&gt;”。\n",
    "我们可以选择增加一个列表，用于保存那些被保留的词元，\n",
    "例如：填充词元（“&lt;pad&gt;”）；\n",
    "序列开始词元（“&lt;bos&gt;”）；\n",
    "序列结束词元（“&lt;eos&gt;”）。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class Vocab {\n",
    "    public int unk;\n",
    "    public List<Map.Entry<String, Integer>> tokenFreqs;\n",
    "    public List<String> idxToToken;\n",
    "    public HashMap<String, Integer> tokenToIdx;\n",
    "\n",
    "    public Vocab(String[][] tokens, int minFreq, String[] reservedTokens) {\n",
    "        // 按频率排序\n",
    "        LinkedHashMap<String, Integer> counter = countCorpus2D(tokens);\n",
    "        this.tokenFreqs = new ArrayList<Map.Entry<String, Integer>>(counter.entrySet()); \n",
    "        Collections.sort(tokenFreqs, \n",
    "            new Comparator<Map.Entry<String, Integer>>() { \n",
    "                public int compare(Map.Entry<String, Integer> o1, Map.Entry<String, Integer> o2) { \n",
    "                    return (o2.getValue()).compareTo(o1.getValue()); \n",
    "                }\n",
    "            });\n",
    "        \n",
    "        // 未知标记的索引为0\n",
    "        this.unk = 0;\n",
    "        List<String> uniqTokens = new ArrayList<>();\n",
    "        uniqTokens.add(\"<unk>\");\n",
    "        Collections.addAll(uniqTokens, reservedTokens);\n",
    "        for (Map.Entry<String, Integer> entry : tokenFreqs) {\n",
    "            if (entry.getValue() >= minFreq && !uniqTokens.contains(entry.getKey())) {\n",
    "                uniqTokens.add(entry.getKey());\n",
    "            }\n",
    "        }\n",
    "        \n",
    "        this.idxToToken = new ArrayList<>();\n",
    "        this.tokenToIdx = new HashMap<>();\n",
    "        for (String token : uniqTokens) {\n",
    "            this.idxToToken.add(token);\n",
    "            this.tokenToIdx.put(token, this.idxToToken.size()-1);\n",
    "        }\n",
    "    }\n",
    "    \n",
    "    public int length() {\n",
    "        return this.idxToToken.size();\n",
    "    }\n",
    "    \n",
    "    public Integer[] getIdxs(String[] tokens) {\n",
    "        List<Integer> idxs = new ArrayList<>();\n",
    "        for (String token : tokens) {\n",
    "            idxs.add(getIdx(token));\n",
    "        }\n",
    "        return idxs.toArray(new Integer[0]);\n",
    "        \n",
    "    }\n",
    "    \n",
    "    public Integer getIdx(String token) {\n",
    "        return this.tokenToIdx.getOrDefault(token, this.unk);\n",
    "    }\n",
    "    \n",
    "    \n",
    "}\n",
    "\n",
    "public LinkedHashMap<String, Integer> countCorpus(String[] tokens) {\n",
    "    /* 计算token频率 */\n",
    "    LinkedHashMap<String, Integer> counter = new LinkedHashMap<>();\n",
    "    if (tokens.length != 0) {\n",
    "        for (String token : tokens) {\n",
    "            counter.put(token, counter.getOrDefault(token, 0)+1);\n",
    "        }\n",
    "    }\n",
    "    return counter;\n",
    "}\n",
    "\n",
    "public LinkedHashMap<String, Integer> countCorpus2D(String[][] tokens) {\n",
    "    /* 将token列表展平为token列表*/\n",
    "    List<String> allTokens = new ArrayList<String>();\n",
    "    for (int i = 0; i < tokens.length; i++) {\n",
    "        for (int j = 0; j < tokens[i].length; j++) {\n",
    "             if (tokens[i][j] != \"\") {\n",
    "                allTokens.add(tokens[i][j]);\n",
    "             }\n",
    "        }\n",
    "    }\n",
    "    return countCorpus(allTokens.toArray(new String[0]));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们使用时间机器数据集作为语料库构建了一个词汇表。\n",
    "然后，我们打印前几个频繁标记及其索引。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Vocab vocab = new Vocab(tokens, 0, new String[0]);\n",
    "for (int i = 0; i < 10; i++) {\n",
    "    String token = vocab.idxToToken.get(i);\n",
    "    System.out.print(\"(\" + token + \", \" + vocab.tokenToIdx.get(token) + \") \");\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在我们可以将每一行文本转换为一个数字索引列表。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for (int i : new int[] {0,10}) {\n",
    "    System.out.println(\"Words:\" + Arrays.toString(tokens[i]));\n",
    "    System.out.println(\"Indices:\" + Arrays.toString(vocab.getIdxs(tokens[i])));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 把所有的东西放在一起\n",
    "\n",
    "使用上述函数，我们将所有内容打包到`loadCorpusTimeMachine`函数中，\n",
    "该函数返回`corpus`，一个标记索引列表，以及`vocab`，时间机器语料库的词汇表。\n",
    "我们在这里做的修改是：\n",
    "一） 我们将文本标记为字符，而不是单词，以简化后面章节中的训练；\n",
    "二）`corpus` 是一个单一的列表，而不是标记列表列表，因为时间机器数据集中的每一行文本不一定是一个句子或段落。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "public Pair<List<Integer>, Vocab> loadCorpusTimeMachine(int maxTokens) throws IOException, Exception {\n",
    "    /* 返回时间机器数据集的令牌索引和词汇表。 */\n",
    "    String[] lines = readTimeMachine();\n",
    "    String[][] tokens = tokenize(lines, \"char\");\n",
    "    Vocab vocab = new Vocab(tokens, 0, new String[0]);\n",
    "    // 因为时间机器数据集中的每个文本行不一定是\n",
    "    // 句子或段落，将所有文本行展平为一个列表\n",
    "    List<Integer> corpus = new ArrayList<>();\n",
    "    for (int i = 0; i < tokens.length; i++) {\n",
    "        for (int j = 0; j < tokens[i].length; j++) {\n",
    "            if (tokens[i][j] != \"\") {\n",
    "                corpus.add(vocab.getIdx(tokens[i][j]));\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "    if (maxTokens > 0) {\n",
    "        corpus = corpus.subList(0, maxTokens);\n",
    "    }\n",
    "    return new Pair(corpus, vocab);\n",
    "}\n",
    "\n",
    "Pair<List<Integer>, Vocab> corpusVocabPair = loadCorpusTimeMachine(-1);\n",
    "List<Integer> corpus = corpusVocabPair.getKey();\n",
    "Vocab vocab = corpusVocabPair.getValue();\n",
    "\n",
    "System.out.println(corpus.size());\n",
    "System.out.println(vocab.length());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 文本是序列数据的一种最常见的形式之一。\n",
    "* 为了对文本进行预处理，我们通常将文本拆分为词元，构建词表将词元字符串映射为数字索引，并将文本数据转换为词元索引以供模型操作。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 词元化是一个关键的预处理步骤，它因语言而异。尝试找到另外三种常用的词元化文本的方法。\n",
    "1. 在本节的实验中，将文本词元为单词和更改`Vocab`实例的`minFreq`参数。这对词汇表大小有何影响？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
